import taichi as ti
from ..camera import Camera
from ..utils import *


class Spherical(Camera):
    """Panoramic camera that maps pixels to ray directions via spherical angles.

    The horizontal pixel coordinate is mapped linearly to an azimuth in
    [-PI, PI] and the vertical one to a polar angle in [0, PI]; the resulting
    direction is expressed in the camera's local frame and rotated into world
    space with the basis stored in ``self.xyz``.
    """

    def __init__(self, size: "tuple[int, int]"):
        """Create the camera and build its initial orientation basis.

        Args:
            size: Forwarded unchanged to ``Camera.__init__``; presumably the
                image (width, height) in pixels -- confirm against ``Camera``.
        """
        super().__init__(size)
        # One-element matrix field holding the camera basis so that it can be
        # read from inside Taichi functions such as ``rayAt``.
        self.xyz: ti.MatrixField = Mat3x3.field(shape=1)
        self.setScreen()

    def setScreen(self):
        """Recompute the camera basis from ``self.theta`` and ``self.phi``.

        ``z`` is the unit vector with polar angle ``theta`` measured from the
        +Y axis and azimuth ``phi`` in the XZ plane.  ``x`` is the horizontal
        unit vector at azimuth ``phi - PI / 2`` (perpendicular to ``z``), and
        ``y = z x x`` completes the frame.

        Updates ``self.x``, ``self.y``, ``self.z`` (each a ``Vec4`` with a zero
        fourth component) and ``self.xyz[0]``.

        Returns:
            ``self``, so calls can be chained.
        """
        sinPolar = ti.sin(self.theta)
        cosPolar = ti.cos(self.theta)
        forward = Vec3(
            ti.cos(self.phi) * sinPolar,
            cosPolar,
            ti.sin(self.phi) * sinPolar,
        )

        # The sideways axis stays in the horizontal (XZ) plane, a quarter turn
        # behind the forward azimuth.
        sideAzimuth = self.phi - PI / 2
        right = Vec3(ti.cos(sideAzimuth), 0, ti.sin(sideAzimuth))
        up = forward.cross(right)

        right4 = Vec4(right, 0)
        up4 = Vec4(up, 0)
        forward4 = Vec4(forward, 0)
        self.x = right4
        self.y = up4
        self.z = forward4

        # Presumably Mat3x3(a, b, c) stacks its arguments as rows, so the
        # transpose puts the axes in columns and ``xyz @ v`` maps local
        # coordinates to world space -- confirm against the Mat3x3 definition.
        axes = Mat3x3(right, up, forward)
        self.xyz[0] = axes.transpose()
        return self

    @ti.func
    def rayAt(self, uv: Vec2, t: int = 0) -> Ray4:
        """Build the ray leaving the camera through pixel position ``uv``.

        Args:
            uv: Pixel coordinates, in the same units as ``self.w``/``self.h``.
            t: Index passed to ``superSampling`` together with
                ``self.SSNumber``; presumably selects the sub-pixel sample --
                confirm against ``superSampling``.

        Returns:
            A ``Ray4`` built from ``self.origin[0]`` and the normalized world
            direction packed into a ``Vec4`` whose fourth component is -1.
        """
        jitter = superSampling(self.SSNumber, t)
        resolution = Vec2(self.w, self.h)
        # Normalized coordinates: each component spans [-1, 1] over the image.
        ndc = (uv + jitter) / resolution * 2 - 1

        polar = (1 + ndc.y) * PI / 2  # 0 .. PI, measured from the local z axis
        azimuth = PI * ndc.x  # -PI .. PI, around the local z axis

        sinPolar = ti.sin(polar)
        localDir = Vec3(
            sinPolar * ti.cos(azimuth),
            sinPolar * ti.sin(azimuth),
            ti.cos(polar),
        )
        worldDir = (self.xyz[0] @ localDir).normalized()
        return Ray4(self.origin[0], Vec4(worldDir, -1))
